import numpy as np


def csvd(A):
    """计算A的紧奇异值分解

    :param A: 目标矩阵
    :return:
        m×r阶矩阵（完全奇异值分解的前r列）
        r阶对角矩阵（完全奇异值分解的前r个对角线元素）
        n×r阶矩阵的转置（完全奇异值分解的前r行）
    """

    # 计算对称矩阵 W = A^T*A
    W = np.dot(A.T, A)

    # 求解矩阵的特征值和特征向量：返回的特征值是升序的、特征向量已经被单位化
    # (因为已知W是对称矩阵，所以使用eigh而不是eig)
    l, V = np.linalg.eigh(W)

    # 计算特征值的平方根并由大到小排列
    l = np.sqrt(l[::-1])

    # ---------- 2.求r×r的对角矩阵S的剪辑后结果 ----------
    # 构造r×r的对角矩阵S
    S = np.diag([v for v in l if v > 0])

    # 计算矩阵的秩
    r = S.shape[0]

    # ---------- 3.求n阶正交矩阵V的剪辑后结果 ----------
    # * 因为特征向量对应的特征值已经升序，所以直接翻转即可
    # * 特征向量已经被单位化
    V = V[:, -1:-1 - r:-1]

    # ---------- 4.求m阶正交矩阵U的剪辑后结果 ----------
    U = np.hstack([(np.dot(A, V[:, i]) / l[i])[:, np.newaxis] for i in range(r)])

    return U, S, V.T


if __name__ == "__main__":
    A = np.array([[1, 1],
                  [2, 2],
                  [0, 0]])

    U, S, VT = csvd(A)

    print(U)
    # [[0.4472136 ]
    #  [0.89442719]
    #  [0.        ]]

    print(S)
    # [[3.16227766]]

    print(VT)
    # [[0.70710678 0.70710678]]
